{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Route Hugging Face Hub traffic through the mirror. This is set before\n",
    "# `transformers` is imported so the endpoint is already in the environment\n",
    "# when the Hugging Face libraries load.\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install datasets  -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Hub id of the dataset and the local directory used as its download cache.\n",
    "DATASET_ID = \"jitx/Methods2Test_java_unit_test_code\"\n",
    "DATASET_CACHE_DIR = \"data/\"\n",
    "\n",
    "# Needs network access to the Hub; a saved run of this cell failed with\n",
    "# ConnectionError when the Hub could not be reached.\n",
    "ds = load_dataset(DATASET_ID, cache_dir=DATASET_CACHE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model weights and the tokenizer are fetched from the same Hub repo.\n",
    "MODEL_ID = \"ayoub-edh/Finetuned_SANTACODER_Java_Unit_Test_Generator\"\n",
    "\n",
    "# Load the weights and move them to the selected device in one step.\n",
    "model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(device=device)\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the `target` and `src_fm` fields of one training row.\n",
    "idx = 1\n",
    "sample = ds['train'][idx]\n",
    "for field in ('target', 'src_fm'):\n",
    "    print(sample[field])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip one example through the tokenizer: text -> token ids -> text.\n",
    "input_text = ds['train'][10]['src_fm']\n",
    "# input_text = \"Hello, how are you?\"\n",
    "print(input_text)\n",
    "\n",
    "# Encode to a tensor of token ids and place it on the selected device.\n",
    "token_ids = tokenizer.encode(input_text, return_tensors='pt')\n",
    "inputs = token_ids.to(device=device)\n",
    "print(inputs)\n",
    "\n",
    "# Decode the first (only) sequence back to text, dropping special tokens.\n",
    "de = tokenizer.decode(inputs[0], skip_special_tokens=True)\n",
    "print(de)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 6. Training loop\n",
    "# Reuse EOS as the padding token so that `padding=True` below is usable.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "optimizer = optim.Adam(model.parameters(), lr=5e-5)  # Adam optimizer, learning rate 5e-5\n",
    "\n",
    "num_epochs = 100  # number of training epochs\n",
    "voc_size = 500  # number of training examples visited per epoch (despite the name)\n",
    "train_split = ds['train']\n",
    "# Never index past the end of the split if it holds fewer than `voc_size` rows.\n",
    "num_samples = min(voc_size, len(train_split))\n",
    "\n",
    "model.train()  # put the model in training mode\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = 0.0\n",
    "    for i in range(num_samples):\n",
    "        # NOTE(review): only `src_fm` is used as training text; the `target` field\n",
    "        # does not contribute to the loss - confirm this is intended.\n",
    "        # The encoding gets a loop-local name so the prompt tensor that an earlier\n",
    "        # cell stored in `inputs` is not overwritten.\n",
    "        encoded = tokenizer(train_split[i]['src_fm'], truncation=True, padding=True, return_tensors='pt')\n",
    "        input_ids = encoded.input_ids.to(device=device)\n",
    "        optimizer.zero_grad()  # clear gradients from the previous step\n",
    "        output = model(input_ids, labels=input_ids)  # forward pass; the model returns the loss\n",
    "        loss = output.loss\n",
    "        loss.backward()  # backpropagate to compute gradients\n",
    "        optimizer.step()  # update the parameters\n",
    "        epoch_loss += loss.item()\n",
    "    # Report the mean loss over the epoch rather than only the last example's loss.\n",
    "    print(f\"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / max(num_samples, 1)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the prompt here instead of relying on a global `inputs` that other\n",
    "# cells may have rebound; this also yields the matching attention mask.\n",
    "prompt = tokenizer(input_text, return_tensors='pt').to(device)\n",
    "\n",
    "# The training cell leaves the model in train mode; switch to evaluation mode\n",
    "# before generating.\n",
    "model.eval()\n",
    "output = model.generate(\n",
    "    **prompt,\n",
    "    # Budget counts newly generated tokens only, so a long prompt cannot use it up.\n",
    "    max_new_tokens=200,\n",
    "    num_return_sequences=1,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    ")\n",
    "# Decode the generated token ids back to text\n",
    "generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "print(generated_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
